/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 * ===================================================================================================================*/

#include <gtest/gtest.h>
#include <iostream>
#include <string>
#include "nlohmann/json.hpp"
#include "graph/ascend_string.h"
#include "register/tuning_bank_key_registry.h"

namespace tuningtiling {
// Bank-key argument structure used by this test for the DynamicRNN op.
// The field names are handed to DECLARE_STRUCT_RELATE_WITH_OP in this file and
// coincide with the JSON keys ("batch", "dims") used by the test case, so they
// should not be renamed without updating both places.
struct DynamicRnnInputArgsV2 {
  int64_t batch;  // ConvertTilingContext sets this to 0; the test loads 12 from JSON
  int32_t dims;   // ConvertTilingContext sets this to 1; the test loads 2 from JSON
};
// Test-only bank-key convert function registered for the DynamicRNN op.
//
// @param context     tiling context to convert; only tested against nullptr.
// @param input_args  [out] when context is null, receives a freshly allocated
//                    DynamicRnnInputArgsV2 with batch = 0 and dims = 1.
// @param size        [out] when context is null, receives
//                    sizeof(DynamicRnnInputArgsV2).
// @return false when context is null (outputs are still filled in, which is
//         what the test case checks); true otherwise, with both outputs left
//         untouched.
bool ConvertTilingContext(const gert::TilingContext* context,
                          std::shared_ptr<void> &input_args, size_t &size) {
  // A non-null context needs no conversion in this stub.
  if (context != nullptr) {
    return true;
  }

  const auto default_args = std::make_shared<DynamicRnnInputArgsV2>();
  default_args->batch = 0;
  default_args->dims = 1;

  input_args = default_args;
  size = sizeof(DynamicRnnInputArgsV2);
  return false;
}

// Registration used by the test case in this file: the test looks up
// "DynamicRNN" in OpBankKeyFuncRegistry and expects load, parse and convert
// functions to be present, which these two macro invocations are expected to
// provide (struct fields batch/dims, convert function ConvertTilingContext).
// NOTE(review): the original note here read "maybe del" - presumably these are
// candidates for removal once the commented-out V2 macros below are enabled;
// confirm before deleting.
DECLARE_STRUCT_RELATE_WITH_OP(DynamicRNN, DynamicRnnInputArgsV2, batch, dims);
REGISTER_OP_BANK_KEY_CONVERT_FUN(DynamicRNN, ConvertTilingContext);

// New (V2) API registration - currently disabled (kept commented out):
// DECLARE_STRUCT_RELATE_WITH_OP_V2(DynamicRNN, DynamicRnnInputArgsV2,
//   batch, dims);
// REGISTER_OP_BANK_KEY_CONVERT_FUN_V2(DynamicRNN, ConvertTilingContext);
class RegisterOPBankKeyUT : public testing::Test {
 protected:
  void SetUp() {}

  void TearDown() {}
};

// Declares, with C linkage and no parameters, a symbol whose name has the form
// of an Itanium-ABI mangled C++ name for a three-argument
// tuningtiling::OpBankKeyFuncRegistry constructor (ge::AscendString plus two
// std::function callbacks over ascend_nlohmann::basic_json).
// NOTE(review): the test below calls this symbol with no arguments. That is
// only well-defined if the library exports a plain C function under this
// name (presumably an ABI-compatibility stub) - confirm; calling the real
// constructor this way would be undefined behaviour.
extern "C" void _ZN12tuningtiling21OpBankKeyFuncRegistryC1ERKN2ge12AscendStringERKSt8functionIFbRKSt10shared_ptrIvEmRN15ascend_nlohmann10basic_jsonISt3mapSt6vectorSsblmdSaNSA_14adl_serializerESD_IhSaIhEEEEEERKS5_IFbRS7_RmRKSH_EE();

// Exercises the three functions registered for "DynamicRNN":
//   1. load    - fills an opaque buffer (ld/len) from a JSON object,
//   2. parse   - writes that buffer back into JSON (round trip must match),
//   3. convert - ConvertTilingContext, called with a null tiling context.
// Null checks that guard a later dereference use ASSERT_* so that a failure
// ends the test instead of crashing the test binary.
TEST_F(RegisterOPBankKeyUT, convert_tiling_context) {
  // Call the C-linkage symbol declared above with no arguments.
  // NOTE(review): presumably a no-op compatibility stub - confirm.
  _ZN12tuningtiling21OpBankKeyFuncRegistryC1ERKN2ge12AscendStringERKSt8functionIFbRKSt10shared_ptrIvEmRN15ascend_nlohmann10basic_jsonISt3mapSt6vectorSsblmdSaNSA_14adl_serializerESD_IhSaIhEEEEEERKS5_IFbRS7_RmRKSH_EE();

  auto& func = OpBankKeyFuncRegistry::RegisteredOpFuncInfo();
  auto iter = func.find("DynamicRNN");
  // Everything below dereferences iter, so a missing registration is fatal.
  ASSERT_TRUE(iter != func.cend());

  nlohmann::json test;
  test["batch"] = 12;
  test["dims"] = 2;

  // Load step.
  const OpBankLoadFun& load_func = iter->second.GetBankKeyLoadFunc();
  std::shared_ptr<void> ld = nullptr;
  size_t len = 0;
  EXPECT_TRUE(load_func(ld, len, test));
  // The parse step is handed ld; do not continue with a null buffer.
  ASSERT_TRUE(ld != nullptr);

  // Parse step: the JSON produced from the buffer must equal the input JSON.
  const auto &parse_func = iter->second.GetBankKeyParseFunc();
  nlohmann::json test2;
  EXPECT_TRUE(parse_func(ld, len, test2));
  EXPECT_EQ(test, test2);

  // Convert step with a null context: ConvertTilingContext returns false but
  // still produces a DynamicRnnInputArgsV2 with batch = 0 and dims = 1.
  const auto &convert_func = iter->second.GetBankKeyConvertFunc();
  std::shared_ptr<void> op_key = nullptr;
  size_t s = 0U;
  EXPECT_FALSE(convert_func(nullptr, op_key, s));
  EXPECT_EQ(s, sizeof(DynamicRnnInputArgsV2));
  // op_key is dereferenced below, so a null result must stop the test here.
  ASSERT_TRUE(op_key != nullptr);
  auto rnn_ky = std::static_pointer_cast<DynamicRnnInputArgsV2>(op_key);
  EXPECT_EQ(rnn_ky->batch, 0);
  EXPECT_EQ(rnn_ky->dims, 1);
}

// TEST_F(RegisterOPBankKeyV2UT, convert_tiling_context) {
//   auto& func = OpBankKeyFuncRegistryV2::RegisteredOpFuncInfoV2();
//   auto iter = func.find("DynamicRNN");
//   nlohmann::json test;
//   test["batch"] = 12;
//   test["dims"] = 2;
//   std::string dump_str;
//   dump_str = test.dump();
//   ge::AscendString test_str;
//   test_str = ge::AscendString(dump_str.c_str());
//   ASSERT_TRUE(iter != func.cend());

//   const OpBankLoadFunV2& load_funcV2 = iter->second.GetBankKeyLoadFuncV2();
//   std::shared_ptr<void> ld = nullptr;
//   size_t len = 0;
//   EXPECT_TRUE(load_funcV2(ld, len, test_str));
//   EXPECT_TRUE(ld != nullptr);

//   const auto &parse_funcV2 = iter->second.GetBankKeyParseFuncV2();
//   ge::AscendString test2;
//   EXPECT_TRUE(parse_funcV2(ld, len, test2));
//   EXPECT_EQ(test_str, test2);

//   const auto &convert_funcV2 = iter->second.GetBankKeyConvertFuncV2();
//   std::shared_ptr<void> op_key = nullptr;
//   size_t s = 0U;
//   EXPECT_FALSE(convert_funcV2(nullptr, op_key, s));
//   EXPECT_TRUE(s !=0);
//   EXPECT_TRUE(op_key != nullptr);
//   auto rnn_ky = std::static_pointer_cast<DynamicRnnInputArgsV2>(op_key);
//   EXPECT_EQ(rnn_ky->batch, 0);

// }
}  // namespace tuningtiling
